{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# livelossplot example: PyTorch Ignite\n",
    "\n",
    "Last update: `livelossplot 0.5.0`. For code and documentation, see [livelossplot GitHub repository](https://github.com/stared/livelossplot).\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/stared/livelossplot/blob/master/examples/pytorch-ignite.ipynb\" target=\"_parent\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/>\n",
    "</a>\n",
    "\n",
    "## Note\n",
    "\n",
    "The import syntax changed; from version `0.5` onwards it is:\n",
    "\n",
    "```python\n",
    "from livelossplot import PlotLossesIgnite\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic (rather than !pip) so the packages are installed into\n",
    "# the environment of the running kernel.\n",
    "%pip install pytorch-ignite livelossplot --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from torch import nn\n",
    "import torch.optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.transforms import Compose, ToTensor, Normalize\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
    "from ignite.metrics import Accuracy, Loss\n",
    "from ignite.contrib.handlers import ProgressBar\n",
    "\n",
    "from livelossplot import PlotLossesIgnite "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Batch sizes handed to `get_data_loaders` for the training and validation loaders.\n",
    "TRAIN_BATCH_SIZE = 100\n",
    "TEST_BATCH_SIZE = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        self.fc1 = nn.Linear(320, 50)\n",
    "        self.fc2 = nn.Linear(50, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        x = x.view(-1, 320)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "def get_data_loaders(train_batch_size, val_batch_size):\n",
    "    \"\"\"Build the MNIST training and validation ``DataLoader`` objects.\n",
    "\n",
    "    The training split is requested with ``download=True`` and is shuffled;\n",
    "    the validation split is read from the same root directory with\n",
    "    ``download=False`` and is not shuffled.\n",
    "\n",
    "    Returns a ``(train_loader, val_loader)`` tuple.\n",
    "    \"\"\"\n",
    "    data_root = \".\"\n",
    "    # Convert images to tensors, then normalise with mean 0.1307 and std 0.3081.\n",
    "    preprocessing = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n",
    "\n",
    "    train_set = MNIST(root=data_root, train=True, download=True, transform=preprocessing)\n",
    "    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)\n",
    "\n",
    "    val_set = MNIST(root=data_root, train=False, download=False, transform=preprocessing)\n",
    "    val_loader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False)\n",
    "\n",
    "    return train_loader, val_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Network to train.\n",
    "model = Net()\n",
    "\n",
    "# MNIST loaders for the training and validation splits.\n",
    "train_loader, val_loader = get_data_loaders(\n",
    "    TRAIN_BATCH_SIZE,\n",
    "    TEST_BATCH_SIZE,\n",
    ")\n",
    "\n",
    "# SGD with momentum over all model parameters.\n",
    "optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01, momentum=0.8)\n",
    "\n",
    "# Negative log-likelihood loss, matching the log-softmax output of `Net`.\n",
    "loss = nn.NLLLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Training engine built from the model, optimizer and loss defined above.\n",
    "trainer = create_supervised_trainer(model, optimizer, loss)\n",
    "\n",
    "# Metrics attached to the evaluator, keyed by the name they are reported under.\n",
    "metrics = {\n",
    "    'val_acc': Accuracy(),\n",
    "    'val_loss': Loss(loss)\n",
    "}\n",
    "\n",
    "# Attach a progress bar to the training engine.\n",
    "p_bar = ProgressBar()\n",
    "p_bar.attach(trainer)\n",
    "\n",
    "# Evaluation engine that runs the same model and computes the metrics above.\n",
    "evaluator = create_supervised_evaluator(model, metrics=metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "@trainer.on(Events.EPOCH_COMPLETED)\n",
    "def validate(trainer):\n",
    "    \"\"\"Run the evaluator on the validation loader.\n",
    "\n",
    "    Registered on the trainer for ``Events.EPOCH_COMPLETED``; the ``trainer``\n",
    "    argument is not used in the body.\n",
    "    \"\"\"\n",
    "    evaluator.run(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Attach the livelossplot callback to the evaluator, i.e. the engine that\n",
    "# computes the `val_acc` and `val_loss` metrics.\n",
    "callback = PlotLossesIgnite()\n",
    "callback.attach(evaluator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Train on the training loader for 100 epochs; the `validate` handler runs the\n",
    "# evaluator on the validation loader each time an epoch completes.\n",
    "trainer.run(train_loader, max_epochs=100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}